Skip to content

Commit

Permalink
experiments for narrow (#494)
Browse files Browse the repository at this point in the history
  • Loading branch information
googley42 authored Sep 23, 2024
1 parent 8963a30 commit e8edbe7
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object TypeSafeApiMappingSpec extends DynamoDBLocalSpec {
_ <- put[InvoiceWithDiscriminatorName](invoiceTable, InvoiceWithDiscriminatorName.Unpaid("1")).execute
invoice <- // invoice is of type InvoiceWithDiscriminatorName
get(invoiceTable)(
((InvoiceWithDiscriminatorName.unpaid) >>> (InvoiceWithDiscriminatorName.Unpaid.id)).partitionKey === "1"
(InvoiceWithDiscriminatorName.unpaid >>> InvoiceWithDiscriminatorName.Unpaid.id).partitionKey === "1"
).execute.absolve
} yield assertTrue(invoice == InvoiceWithDiscriminatorName.Unpaid("1"))
}
Expand All @@ -81,7 +81,7 @@ object TypeSafeApiMappingSpec extends DynamoDBLocalSpec {
for {
_ <- put[InvoiceWithDiscriminatorName](invoiceTable, InvoiceWithDiscriminatorName.Unpaid("1")).execute
invoice <- get(invoiceTable)( // invoice is of type InvoiceWithDiscriminatorName.Unpaid
(InvoiceWithDiscriminatorName.Unpaid.id).partitionKey === "1"
InvoiceWithDiscriminatorName.Unpaid.id.partitionKey === "1"
).execute.absolve
} yield assertTrue(invoice == InvoiceWithDiscriminatorName.Unpaid("1"))
}
Expand Down
132 changes: 132 additions & 0 deletions dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiNarrowSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package zio.dynamodb

import zio.dynamodb.DynamoDBQuery.{ getItem, put }
import zio.Scope
import zio.test.Spec
import zio.test.Assertion.fails
import zio.test.{ assert, assertTrue }
import zio.test.TestEnvironment
import zio.test.Assertion.{ equalTo, isLeft, isRight }
import zio.test.TestAspect
import zio.schema.Schema
import zio.schema.DeriveSchema
import zio.schema.annotation.discriminatorName
import zio.dynamodb.DynamoDBQuery.getWithNarrow
import zio.dynamodb.DynamoDBError.ItemError

object TypeSafeApiNarrowSpec extends DynamoDBLocalSpec {

object dynamo {
@discriminatorName("invoiceType")
sealed trait Invoice {
def id: String
}
object Invoice {
final case class Unrelated(id: Int)
object Unrelated {
implicit val schema: Schema.CaseClass1[Int, Unrelated] = DeriveSchema.gen[Unrelated]
val id = ProjectionExpression.accessors[Unrelated]
}
final case class Unpaid(id: String) extends Invoice
object Unpaid {
implicit val schema: Schema.CaseClass1[String, Unpaid] = DeriveSchema.gen[Unpaid]
val id = ProjectionExpression.accessors[Unpaid]
}
final case class Paid(id: String, amount: Int) extends Invoice
object Paid {
implicit val schema: Schema.CaseClass2[String, Int, Paid] = DeriveSchema.gen[Paid]
val (id, amount) = ProjectionExpression.accessors[Paid]
}
implicit val schema: Schema.Enum2[Unpaid, Paid, Invoice] =
DeriveSchema.gen[Invoice]
val (unpaid, paid) = ProjectionExpression.accessors[Invoice]
}

}

override def spec: Spec[Environment with TestEnvironment with Scope, Any] =
suite("TypeSafeApiNarrowSpec")(
topLevelSumTypeNarrowSuite,
narrowSuite
) @@ TestAspect.nondeterministic

val topLevelSumTypeNarrowSuite = suite("for top level Invoice sum type with @discriminatorName annotation")(
test("getWithNarrow succeeds in narrowing an Unpaid Invoice instance to Unpaid") {
withSingleIdKeyTable { invoiceTable =>
val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Unpaid] =
dynamo.Invoice.Unpaid.id.partitionKey === "1"
for {
_ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Unpaid("1")).execute
item <- getItem(invoiceTable, PrimaryKey("id" -> "1")).execute

unpaid <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoiceTable)(keyCond).execute.absolve
} yield {
val unpaid2: dynamo.Invoice.Unpaid = unpaid
val ensureDiscriminatorPresent = item == Some(Item("id" -> "1", "invoiceType" -> "Unpaid"))
assertTrue(unpaid2 == dynamo.Invoice.Unpaid("1") && ensureDiscriminatorPresent)
}
}
},
test("getWithNarrow succeeds in narrowing an Paid Invoice instance to Paid") {
withSingleIdKeyTable { invoiceTable =>
val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Paid] =
dynamo.Invoice.Paid.id.partitionKey === "1"
for {
_ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Paid("1", 42)).execute
item <- getItem(invoiceTable, PrimaryKey("id" -> "1")).execute

paid <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Paid](invoiceTable)(keyCond).execute.absolve
} yield {
val paid2: dynamo.Invoice.Paid = paid
val ensureDiscriminatorPresent = item == Some(Item("id" -> "1", "invoiceType" -> "Paid", "amount" -> 42))
assertTrue(paid2 == dynamo.Invoice.Paid("1", 42) && ensureDiscriminatorPresent)
}
}
},
test("getWithNarrow fails in narrowing an Unpaid Invoice instance to Paid") {
withSingleIdKeyTable { invoiceTable =>
val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Paid] =
dynamo.Invoice.Paid.id.partitionKey === "1"
for {
_ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Unpaid("1")).execute
exit <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Paid](invoiceTable)(keyCond).execute.absolve.exit
} yield assert(exit)(
fails(equalTo(ItemError.DecodingError("failed to narrow - found type Unpaid but expected type Paid")))
)
}
},
test("getWithNarrow fails in narrowing an Paid Invoice instance to Unpaid") {
withSingleIdKeyTable { invoiceTable =>
val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Unpaid] =
dynamo.Invoice.Unpaid.id.partitionKey === "1"
for {
_ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Paid("1", 42)).execute
exit <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoiceTable)(keyCond).execute.absolve.exit
} yield assert(exit)(
fails(equalTo(ItemError.DecodingError("failed to narrow - found type Paid but expected type Unpaid")))
)
}
}
)

val narrowSuite = suite("narrow suite")(
test("narrow Paid instance to Paid for success and failure") {
val invoice: dynamo.Invoice = dynamo.Invoice.Paid("1", 1)
val valid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Paid](invoice)
val invalid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoice)

assert(valid)(isRight) && assert(invalid)(
isLeft(equalTo("failed to narrow - found type Paid but expected type Unpaid"))
)
},
test("narrow Unpaid instance to Unpaid for success and failure") {
val invoice: dynamo.Invoice = dynamo.Invoice.Unpaid("1")
val valid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoice)
val invalid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Paid](invoice)

assert(valid)(isRight) && assert(invalid)(
isLeft(equalTo("failed to narrow - found type Unpaid but expected type Paid"))
)
}
)
}
51 changes: 50 additions & 1 deletion dynamodb/src/main/scala/zio/dynamodb/DynamoDBQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,56 @@ object DynamoDBQuery {
): DynamoDBQuery[From, Either[ItemError, From]] =
get(tableName, primaryKeyExpr.asAttrMap, ProjectionExpression.projectionsFromSchema[From])

/**
* It is common practice to save top level sum types to DynamoDB and often we want to retrieve them back as the subtype.
* `getWithNarrow` does a `get` with a safe narrow operation from type `From` to `To`.
* If the narrow fails it returns a Decoding error with details of the cast failure in the message.
* Requires implicit schemas in scope which ensure that `From` is an enum (sealed trait) and `To` is a record (case class) subtype.
*/
def getWithNarrow[From: Schema.Enum, To <: From: Schema.Record](tableName: String)(
primaryKeyExpr: KeyConditionExpr.PrimaryKeyExpr[To]
): DynamoDBQuery[From, Either[ItemError, To]] = {

def getWithNarrowedKeyCondExpr[From: Schema.Enum, To <: From](tableName: String)(
primaryKeyExpr: KeyConditionExpr.PrimaryKeyExpr[To]
): DynamoDBQuery[From, Either[ItemError, From]] =
get(tableName, primaryKeyExpr.asAttrMap, ProjectionExpression.projectionsFromSchema[From])

getWithNarrowedKeyCondExpr[From, To](tableName)(primaryKeyExpr).map {
case Right(found) =>
narrow[From, To](found).left.map(DynamoDBError.ItemError.DecodingError.apply)

case Left(error) => Left(error)
}
}

// Safely narrows `a: From` to subtype type `To` and requires that there are implicit schemas in scope which
// ensure that `From` is an enum (sealed trait) and `To` is a record (case class) subtype.
private[dynamodb] def narrow[From: Schema.Enum, To <: From: Schema.Record](
a: From
): Either[String, To] = {
val fromEnumSchema: Schema.Enum[From] = implicitly[Schema.Enum[From]]
val toSchema: Schema.Record[To] = implicitly[Schema.Record[To]]
val o: Option[Schema.Case[From, _]] = fromEnumSchema.caseOf(a)

o match {
case Some(c @ Schema.Case(_, Schema.Lazy(s), _, _, _, _)) =>
s() == toSchema match {
case true => Right(a.asInstanceOf[To])
case _ =>
Left(s"failed to narrow - found type ${c.id} but expected type ${toSchema.id.name}")
}
case Some(c) =>
c.schema == toSchema match {
case true => Right(a.asInstanceOf[To])
case _ => Left(s"failed to narrow - found type ${c.id} but expected type ${toSchema.id.name}")
}
case None =>
// this should never happen as we have a type level proof
Left(s"failed to narrow - argument is not a subtype of ${fromEnumSchema.id.name}")
}
}

private def get[A: Schema](
tableName: String,
key: PrimaryKey,
Expand Down Expand Up @@ -546,7 +596,6 @@ object DynamoDBQuery {
/**
* when executed will return a Tuple of {{{Either[String,(Chunk[A], LastEvaluatedKey)]}}}
*/

def scanSome[A: Schema](
tableName: String,
limit: Int
Expand Down

0 comments on commit e8edbe7

Please sign in to comment.