From e8edbe7595aa68fbff936abdea32428da033807a Mon Sep 17 00:00:00 2001 From: Avinder Bahra Date: Mon, 23 Sep 2024 14:19:04 +0100 Subject: [PATCH] experiments for narrow (#494) --- .../zio/dynamodb/TypeSafeApiMappingSpec.scala | 4 +- .../zio/dynamodb/TypeSafeApiNarrowSpec.scala | 132 ++++++++++++++++++ .../scala/zio/dynamodb/DynamoDBQuery.scala | 51 ++++++- 3 files changed, 184 insertions(+), 3 deletions(-) create mode 100644 dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiNarrowSpec.scala diff --git a/dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiMappingSpec.scala b/dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiMappingSpec.scala index 8b942d208..add35cc8e 100644 --- a/dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiMappingSpec.scala +++ b/dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiMappingSpec.scala @@ -71,7 +71,7 @@ object TypeSafeApiMappingSpec extends DynamoDBLocalSpec { _ <- put[InvoiceWithDiscriminatorName](invoiceTable, InvoiceWithDiscriminatorName.Unpaid("1")).execute invoice <- // invoice is of type InvoiceWithDiscriminatorName get(invoiceTable)( - ((InvoiceWithDiscriminatorName.unpaid) >>> (InvoiceWithDiscriminatorName.Unpaid.id)).partitionKey === "1" + (InvoiceWithDiscriminatorName.unpaid >>> InvoiceWithDiscriminatorName.Unpaid.id).partitionKey === "1" ).execute.absolve } yield assertTrue(invoice == InvoiceWithDiscriminatorName.Unpaid("1")) } @@ -81,7 +81,7 @@ object TypeSafeApiMappingSpec extends DynamoDBLocalSpec { for { _ <- put[InvoiceWithDiscriminatorName](invoiceTable, InvoiceWithDiscriminatorName.Unpaid("1")).execute invoice <- get(invoiceTable)( // invoice is of type InvoiceWithDiscriminatorName.Unpaid - (InvoiceWithDiscriminatorName.Unpaid.id).partitionKey === "1" + InvoiceWithDiscriminatorName.Unpaid.id.partitionKey === "1" ).execute.absolve } yield assertTrue(invoice == InvoiceWithDiscriminatorName.Unpaid("1")) } diff --git a/dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiNarrowSpec.scala b/dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiNarrowSpec.scala new file mode 100644 index 000000000..6a055c989 --- /dev/null +++ b/dynamodb/src/it/scala/zio/dynamodb/TypeSafeApiNarrowSpec.scala @@ -0,0 +1,132 @@ +package zio.dynamodb + +import zio.dynamodb.DynamoDBQuery.{ getItem, put } +import zio.Scope +import zio.test.Spec +import zio.test.Assertion.fails +import zio.test.{ assert, assertTrue } +import zio.test.TestEnvironment +import zio.test.Assertion.{ equalTo, isLeft, isRight } +import zio.test.TestAspect +import zio.schema.Schema +import zio.schema.DeriveSchema +import zio.schema.annotation.discriminatorName +import zio.dynamodb.DynamoDBQuery.getWithNarrow +import zio.dynamodb.DynamoDBError.ItemError + +object TypeSafeApiNarrowSpec extends DynamoDBLocalSpec { + + object dynamo { + @discriminatorName("invoiceType") + sealed trait Invoice { + def id: String + } + object Invoice { + final case class Unrelated(id: Int) + object Unrelated { + implicit val schema: Schema.CaseClass1[Int, Unrelated] = DeriveSchema.gen[Unrelated] + val id = ProjectionExpression.accessors[Unrelated] + } + final case class Unpaid(id: String) extends Invoice + object Unpaid { + implicit val schema: Schema.CaseClass1[String, Unpaid] = DeriveSchema.gen[Unpaid] + val id = ProjectionExpression.accessors[Unpaid] + } + final case class Paid(id: String, amount: Int) extends Invoice + object Paid { + implicit val schema: Schema.CaseClass2[String, Int, Paid] = DeriveSchema.gen[Paid] + val (id, amount) = ProjectionExpression.accessors[Paid] + } + implicit val schema: Schema.Enum2[Unpaid, Paid, Invoice] = + DeriveSchema.gen[Invoice] + val (unpaid, paid) = ProjectionExpression.accessors[Invoice] + } + + } + + override def spec: Spec[Environment with TestEnvironment with Scope, Any] = + suite("TypeSafeApiNarrowSpec")( + topLevelSumTypeNarrowSuite, + narrowSuite + ) @@ TestAspect.nondeterministic + + val topLevelSumTypeNarrowSuite = suite("for top level Invoice sum type with @discriminatorName annotation")( + test("getWithNarrow succeeds in narrowing an Unpaid Invoice instance to Unpaid") { + withSingleIdKeyTable { invoiceTable => + val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Unpaid] = + dynamo.Invoice.Unpaid.id.partitionKey === "1" + for { + _ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Unpaid("1")).execute + item <- getItem(invoiceTable, PrimaryKey("id" -> "1")).execute + + unpaid <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoiceTable)(keyCond).execute.absolve + } yield { + val unpaid2: dynamo.Invoice.Unpaid = unpaid + val ensureDiscriminatorPresent = item == Some(Item("id" -> "1", "invoiceType" -> "Unpaid")) + assertTrue(unpaid2 == dynamo.Invoice.Unpaid("1") && ensureDiscriminatorPresent) + } + } + }, + test("getWithNarrow succeeds in narrowing an Paid Invoice instance to Paid") { + withSingleIdKeyTable { invoiceTable => + val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Paid] = + dynamo.Invoice.Paid.id.partitionKey === "1" + for { + _ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Paid("1", 42)).execute + item <- getItem(invoiceTable, PrimaryKey("id" -> "1")).execute + + paid <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Paid](invoiceTable)(keyCond).execute.absolve + } yield { + val paid2: dynamo.Invoice.Paid = paid + val ensureDiscriminatorPresent = item == Some(Item("id" -> "1", "invoiceType" -> "Paid", "amount" -> 42)) + assertTrue(paid2 == dynamo.Invoice.Paid("1", 42) && ensureDiscriminatorPresent) + } + } + }, + test("getWithNarrow fails in narrowing an Unpaid Invoice instance to Paid") { + withSingleIdKeyTable { invoiceTable => + val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Paid] = + dynamo.Invoice.Paid.id.partitionKey === "1" + for { + _ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Unpaid("1")).execute + exit <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Paid](invoiceTable)(keyCond).execute.absolve.exit + } yield assert(exit)( + fails(equalTo(ItemError.DecodingError("failed to narrow - found type Unpaid but expected type Paid"))) + ) + } + }, + test("getWithNarrow fails in narrowing an Paid Invoice instance to Unpaid") { + withSingleIdKeyTable { invoiceTable => + val keyCond: KeyConditionExpr.PartitionKeyEquals[dynamo.Invoice.Unpaid] = + dynamo.Invoice.Unpaid.id.partitionKey === "1" + for { + _ <- put[dynamo.Invoice](invoiceTable, dynamo.Invoice.Paid("1", 42)).execute + exit <- getWithNarrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoiceTable)(keyCond).execute.absolve.exit + } yield assert(exit)( + fails(equalTo(ItemError.DecodingError("failed to narrow - found type Paid but expected type Unpaid"))) + ) + } + } + ) + + val narrowSuite = suite("narrow suite")( + test("narrow Paid instance to Paid for success and failure") { + val invoice: dynamo.Invoice = dynamo.Invoice.Paid("1", 1) + val valid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Paid](invoice) + val invalid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoice) + + assert(valid)(isRight) && assert(invalid)( + isLeft(equalTo("failed to narrow - found type Paid but expected type Unpaid")) + ) + }, + test("narrow Unpaid instance to Unpaid for success and failure") { + val invoice: dynamo.Invoice = dynamo.Invoice.Unpaid("1") + val valid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Unpaid](invoice) + val invalid = DynamoDBQuery.narrow[dynamo.Invoice, dynamo.Invoice.Paid](invoice) + + assert(valid)(isRight) && assert(invalid)( + isLeft(equalTo("failed to narrow - found type Unpaid but expected type Paid")) + ) + } + ) +} diff --git a/dynamodb/src/main/scala/zio/dynamodb/DynamoDBQuery.scala b/dynamodb/src/main/scala/zio/dynamodb/DynamoDBQuery.scala index 34402d4bd..2563434a0 100644 --- a/dynamodb/src/main/scala/zio/dynamodb/DynamoDBQuery.scala +++ b/dynamodb/src/main/scala/zio/dynamodb/DynamoDBQuery.scala @@ -480,6 +480,56 @@ object DynamoDBQuery { ): DynamoDBQuery[From, Either[ItemError, From]] = get(tableName, primaryKeyExpr.asAttrMap, ProjectionExpression.projectionsFromSchema[From]) + /** + * It is common practice to save top level sum types to DynamoDB and often we want to retrieve them back as the subtype. + * `getWithNarrow` does a `get` with a safe narrow operation from type `From` to `To`. + * If the narrow fails it returns a Decoding error with details of the cast failure in the message. + * Requires implicit schemas in scope which ensure that `From` is an enum (sealed trait) and `To` is a record (case class) subtype. + */ + def getWithNarrow[From: Schema.Enum, To <: From: Schema.Record](tableName: String)( + primaryKeyExpr: KeyConditionExpr.PrimaryKeyExpr[To] + ): DynamoDBQuery[From, Either[ItemError, To]] = { + + def getWithNarrowedKeyCondExpr[From: Schema.Enum, To <: From](tableName: String)( + primaryKeyExpr: KeyConditionExpr.PrimaryKeyExpr[To] + ): DynamoDBQuery[From, Either[ItemError, From]] = + get(tableName, primaryKeyExpr.asAttrMap, ProjectionExpression.projectionsFromSchema[From]) + + getWithNarrowedKeyCondExpr[From, To](tableName)(primaryKeyExpr).map { + case Right(found) => + narrow[From, To](found).left.map(DynamoDBError.ItemError.DecodingError.apply) + + case Left(error) => Left(error) + } + } + + // Safely narrows `a: From` to subtype type `To` and requires that there are implicit schemas in scope which + // ensure that `From` is an enum (sealed trait) and `To` is a record (case class) subtype. + private[dynamodb] def narrow[From: Schema.Enum, To <: From: Schema.Record]( + a: From + ): Either[String, To] = { + val fromEnumSchema: Schema.Enum[From] = implicitly[Schema.Enum[From]] + val toSchema: Schema.Record[To] = implicitly[Schema.Record[To]] + val o: Option[Schema.Case[From, _]] = fromEnumSchema.caseOf(a) + + o match { + case Some(c @ Schema.Case(_, Schema.Lazy(s), _, _, _, _)) => + s() == toSchema match { + case true => Right(a.asInstanceOf[To]) + case _ => + Left(s"failed to narrow - found type ${c.id} but expected type ${toSchema.id.name}") + } + case Some(c) => + c.schema == toSchema match { + case true => Right(a.asInstanceOf[To]) + case _ => Left(s"failed to narrow - found type ${c.id} but expected type ${toSchema.id.name}") + } + case None => + // this should never happen as we have a type level proof + Left(s"failed to narrow - argument is not a subtype of ${fromEnumSchema.id.name}") + } + } + private def get[A: Schema]( tableName: String, key: PrimaryKey, @@ -546,7 +596,6 @@ object DynamoDBQuery { /** * when executed will return a Tuple of {{{Either[String,(Chunk[A], LastEvaluatedKey)]}}} */ - def scanSome[A: Schema]( tableName: String, limit: Int