From 94bf29de56703337143fc56e2e0166d9ec4fcab0 Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Wed, 7 Aug 2024 12:09:55 +0200 Subject: [PATCH] Non empty collection schemas (#717) --- .../scala-2/zio/schema/DynamicValueGen.scala | 2 + .../scala/zio/schema/codec/JsonCodec.scala | 43 +++++---- .../zio/schema/optics/ZioOpticsBuilder.scala | 42 ++++++++- .../src/main/scala/zio/schema/Schema.scala | 89 ++++++++++++++++++- 4 files changed, 150 insertions(+), 26 deletions(-) diff --git a/tests/shared/src/test/scala-2/zio/schema/DynamicValueGen.scala b/tests/shared/src/test/scala-2/zio/schema/DynamicValueGen.scala index d5a5e73d4..6c1a2c7da 100644 --- a/tests/shared/src/test/scala-2/zio/schema/DynamicValueGen.scala +++ b/tests/shared/src/test/scala-2/zio/schema/DynamicValueGen.scala @@ -74,7 +74,9 @@ object DynamicValueGen { case Schema.Enum22(_, case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22, _) => anyDynamicValueOfEnum(Chunk(case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22)) case Schema.EnumN(_, cases, _) => anyDynamicValueOfEnum(Chunk.fromIterable(cases.toSeq)) case Schema.Sequence(schema, _, _, _, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_)) + case Schema.NonEmptySequence(schema, _, _, _, _) => Gen.chunkOfBounded(1, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_)) case Schema.Map(ks, vs, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_)) + case Schema.NonEmptyMap(ks, vs, _) => Gen.chunkOfBounded(1, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_)) case Schema.Set(schema, _) => Gen.setOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.SetValue(_)) case Schema.Optional(schema, _) => Gen.oneOf(anyDynamicSomeValueOfSchema(schema), Gen.const(DynamicValue.NoneValue)) case Schema.Tuple2(left, right, _) => anyDynamicTupleValue(left, right) diff --git a/zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala b/zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala index 23a5f5635..928a33a99 100644 --- a/zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala +++ b/zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala @@ -10,13 +10,8 @@ import zio.json.JsonCodec._ import zio.json.JsonDecoder.{ JsonError, UnsafeJson } import zio.json.ast.Json import zio.json.internal.{ Lexer, RecordingReader, RetractReader, StringMatrix, WithRecordingReader, Write } -import zio.json.{ - JsonCodec => ZJsonCodec, - JsonDecoder => ZJsonDecoder, - JsonEncoder => ZJsonEncoder, - JsonFieldDecoder, - JsonFieldEncoder -} +import zio.json.{JsonCodec => ZJsonCodec, JsonDecoder => ZJsonDecoder, JsonEncoder => ZJsonEncoder, JsonFieldDecoder, JsonFieldEncoder} +import zio.prelude.NonEmptyMap import zio.schema._ import zio.schema.annotation._ import zio.schema.codec.DecodeError.ReadError @@ -182,9 +177,11 @@ object JsonCodec { //scalafmt: { maxColumn = 400, optIn.configStyleArguments = false } private[codec] def schemaEncoder[A](schema: Schema[A], cfg: Config, discriminatorTuple: DiscriminatorTuple = Chunk.empty): ZJsonEncoder[A] = schema match { - case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder - case Schema.Sequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g) - case Schema.Map(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg) + case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder + case Schema.Sequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g) + case Schema.NonEmptySequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g) + case Schema.Map(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg) + case Schema.NonEmptyMap(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg).contramap[NonEmptyMap[Any, Any]](_.toMap.asInstanceOf[Map[Any, Any]]).asInstanceOf[ZJsonEncoder[A]] case Schema.Set(s, _) => ZJsonEncoder.chunk(schemaEncoder(s, cfg, discriminatorTuple)).contramap(m => Chunk.fromIterable(m)) case Schema.Transform(c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg) @@ -544,18 +541,20 @@ object JsonCodec { //scalafmt: { maxColumn = 400, optIn.configStyleArguments = false } private[codec] def schemaDecoder[A](schema: Schema[A], discriminator: Int = -1): ZJsonDecoder[A] = schema match { - case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder - case Schema.Optional(codec, _) => option(schemaDecoder(codec, discriminator)) - case Schema.Tuple2(left, right, _) => ZJsonDecoder.tuple2(schemaDecoder(left, -1), schemaDecoder(right, -1)) - case Schema.Transform(c, f, _, a, _) => schemaDecoder(a.foldLeft(c)((s, a) => s.annotate(a)), discriminator).mapOrFail(f) - case Schema.Sequence(codec, f, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(f) - case Schema.Map(ks, vs, _) => mapDecoder(ks, vs) - case Schema.Set(s, _) => ZJsonDecoder.chunk(schemaDecoder(s, -1)).map(entries => entries.toSet) - case Schema.Fail(message, _) => failDecoder(message) - case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk, schema.annotations.contains(rejectExtraFields())) - case Schema.Either(left, right, _) => ZJsonDecoder.either(schemaDecoder(left, -1), schemaDecoder(right, -1)) - case s @ Schema.Fallback(_, _, _, _) => fallbackDecoder(s) - case l @ Schema.Lazy(_) => schemaDecoder(l.schema, discriminator) + case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder + case Schema.Optional(codec, _) => option(schemaDecoder(codec, discriminator)) + case Schema.Tuple2(left, right, _) => ZJsonDecoder.tuple2(schemaDecoder(left, -1), schemaDecoder(right, -1)) + case Schema.Transform(c, f, _, a, _) => schemaDecoder(a.foldLeft(c)((s, a) => s.annotate(a)), discriminator).mapOrFail(f) + case Schema.Sequence(codec, f, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(f) + case s @ Schema.NonEmptySequence(codec, _, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(s.fromChunk) + case Schema.Map(ks, vs, _) => mapDecoder(ks, vs) + case Schema.NonEmptyMap(ks, vs, _) => mapDecoder(ks, vs).mapOrFail(m => NonEmptyMap.fromMapOption(m).toRight("NonEmptyMap expected")) + case Schema.Set(s, _) => ZJsonDecoder.chunk(schemaDecoder(s, -1)).map(entries => entries.toSet) + case Schema.Fail(message, _) => failDecoder(message) + case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk, schema.annotations.contains(rejectExtraFields())) + case Schema.Either(left, right, _) => ZJsonDecoder.either(schemaDecoder(left, -1), schemaDecoder(right, -1)) + case s @ Schema.Fallback(_, _, _, _) => fallbackDecoder(s) + case l @ Schema.Lazy(_) => schemaDecoder(l.schema, discriminator) //case Schema.Meta(_, _) => astDecoder case s @ Schema.CaseClass0(_, _, _) => caseClass0Decoder(discriminator, s) case s @ Schema.CaseClass1(_, _, _, _) => caseClass1Decoder(discriminator, s) diff --git a/zio-schema-optics/shared/src/main/scala/zio/schema/optics/ZioOpticsBuilder.scala b/zio-schema-optics/shared/src/main/scala/zio/schema/optics/ZioOpticsBuilder.scala index 6da76d0d8..73bf1b27f 100644 --- a/zio-schema-optics/shared/src/main/scala/zio/schema/optics/ZioOpticsBuilder.scala +++ b/zio-schema-optics/shared/src/main/scala/zio/schema/optics/ZioOpticsBuilder.scala @@ -3,6 +3,7 @@ package zio.schema.optics import scala.collection.immutable.ListMap import zio.optics._ +import zio.prelude.NonEmptyMap import zio.schema._ import zio.{ Chunk, ChunkBuilder } @@ -44,14 +45,24 @@ object ZioOpticsBuilder extends AccessorBuilder { collection match { case seq @ Schema.Sequence(_, _, _, _, _) => ZTraversal[S, S, A, A]( - ZioOpticsBuilder.makeSeqTraversalGet(seq), + ZioOpticsBuilder.makeSeqTraversalGet(seq.toChunk), ZioOpticsBuilder.makeSeqTraversalSet(seq) ) + case seq @ Schema.NonEmptySequence(_, _, _, _, _) => + ZTraversal[S, S, A, A]( + ZioOpticsBuilder.makeSeqTraversalGet(seq.toChunk), + ZioOpticsBuilder.makeNonEmptySeqTraversalSet(seq) + ) case Schema.Map(_: Schema[k], _: Schema[v], _) => ZTraversal( ZioOpticsBuilder.makeMapTraversalGet[k, v], ZioOpticsBuilder.makeMapTraversalSet[k, v] ) + case Schema.NonEmptyMap(_: Schema[k], _: Schema[v], _) => + ZTraversal( + ZioOpticsBuilder.makeMapTraversalGet[k, v], + ZioOpticsBuilder.makeNonEmptyMapTraversalSet[k, v] + ) case Schema.Set(_, _) => ZTraversal( ZioOpticsBuilder.makeSetTraversalGet[A], @@ -103,9 +114,9 @@ object ZioOpticsBuilder extends AccessorBuilder { } private[optics] def makeSeqTraversalGet[S, A]( - collection: Schema.Sequence[S, A, _] + toChunk: S => Chunk[A] ): S => Either[(OpticFailure, S), Chunk[A]] = { (whole: S) => - Right(collection.toChunk(whole)) + Right(toChunk(whole)) } private[optics] def makeSeqTraversalSet[S, A]( @@ -123,15 +134,40 @@ object ZioOpticsBuilder extends AccessorBuilder { } Right(collection.fromChunk(builder.result())) } + private[optics] def makeNonEmptySeqTraversalSet[S, A]( + collection: Schema.NonEmptySequence[S, A, _] + ): Chunk[A] => S => Either[(OpticFailure, S), S] = { (piece: Chunk[A]) => (whole: S) => + val builder = ChunkBuilder.make[A]() + val leftIterator = collection.toChunk(whole).iterator + val rightIterator = piece.iterator + while (leftIterator.hasNext && rightIterator.hasNext) { + val _ = leftIterator.next() + builder += rightIterator.next() + } + while (leftIterator.hasNext) { + builder += leftIterator.next() + } + Right(collection.fromChunk(builder.result())) + } private[optics] def makeMapTraversalGet[K, V](whole: Map[K, V]): Either[(OpticFailure, Map[K, V]), Chunk[(K, V)]] = Right(Chunk.fromIterable(whole)) + private[optics] def makeMapTraversalGet[K, V]( + whole: NonEmptyMap[K, V] + ): Either[(OpticFailure, NonEmptyMap[K, V]), Chunk[(K, V)]] = + Right(Chunk.fromIterable(whole.toMap)) + private[optics] def makeMapTraversalSet[K, V] : Chunk[(K, V)] => Map[K, V] => Either[(OpticFailure, Map[K, V]), Map[K, V]] = { (piece: Chunk[(K, V)]) => (whole: Map[K, V]) => Right(whole ++ piece.toList) } + private[optics] def makeNonEmptyMapTraversalSet[K, V] + : Chunk[(K, V)] => NonEmptyMap[K, V] => Either[(OpticFailure, NonEmptyMap[K, V]), NonEmptyMap[K, V]] = { + (piece: Chunk[(K, V)]) => (whole: NonEmptyMap[K, V]) => + Right(whole ++ piece.toList) + } private[optics] def makeSetTraversalGet[A](whole: Set[A]): Either[(OpticFailure, Set[A]), Chunk[A]] = Right(Chunk.fromIterable(whole)) diff --git a/zio-schema/shared/src/main/scala/zio/schema/Schema.scala b/zio-schema/shared/src/main/scala/zio/schema/Schema.scala index ec792b7ff..b8601a57e 100644 --- a/zio-schema/shared/src/main/scala/zio/schema/Schema.scala +++ b/zio-schema/shared/src/main/scala/zio/schema/Schema.scala @@ -6,11 +6,12 @@ import java.time.temporal.ChronoUnit import scala.annotation.tailrec import scala.collection.immutable.ListMap +import zio.prelude.NonEmptySet import zio.schema.annotation._ import zio.schema.internal.SourceLocation import zio.schema.meta._ import zio.schema.validation._ -import zio.{ Chunk, Unsafe } +import zio.{ Chunk, NonEmptyChunk, Unsafe, prelude } /** * A `Schema[A]` describes the structure of some data type `A`, in terms of case classes, @@ -196,6 +197,8 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality { schema match { case Sequence(schema, _, toChunk, _, _) => toChunk(value).flatMap(value => loop(value, schema)) + case nes @ NonEmptySequence(schema, _, _, _, _) => + nes.toChunk(value).flatMap(value => loop(value, schema)) case Transform(schema, _, g, _, _) => g(value) match { case Right(value) => loop(value, schema) @@ -211,6 +214,10 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality { loop(tuple.extract1(value), left) ++ loop(tuple.extract2(value), right) case l @ Lazy(_) => loop(value, l.schema) + case Schema.NonEmptyMap(ks, vs, _) => + Chunk.fromIterable(value.toMap.keys).flatMap(loop(_, ks)) ++ Chunk + .fromIterable(value.values) + .flatMap(loop(_, vs)) case Schema.Map(ks, vs, _) => Chunk.fromIterable(value.keys).flatMap(loop(_, ks)) ++ Chunk.fromIterable(value.values).flatMap(loop(_, vs)) case set @ Schema.Set(as, _) => @@ -299,6 +306,30 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality { implicit def chunk[A](implicit schemaA: Schema[A]): Schema[Chunk[A]] = Schema.Sequence[Chunk[A], A, String](schemaA, identity, identity, Chunk.empty, "Chunk") + implicit def nonEmptyChunk[A](implicit schemaA: Schema[A]): Schema[NonEmptyChunk[A]] = + Schema.NonEmptySequence[NonEmptyChunk[A], A, String]( + schemaA, + NonEmptyChunk.fromChunk, + _.toChunk, + Chunk.empty, + "NonEmptyChunk" + ) + + implicit def nonEmptySet[A](implicit schemaA: Schema[A]): Schema[NonEmptySet[A]] = + Schema.NonEmptySequence[NonEmptySet[A], A, String]( + schemaA, + chunk => NonEmptySet.fromSetOption(chunk.toSet), + _.toNonEmptyChunk.toChunk, + Chunk.empty, + "NonEmptySet" + ) + + implicit def nonEmptyMap[K, V]( + implicit keySchema: Schema[K], + valueSchema: Schema[V] + ): Schema[prelude.NonEmptyMap[K, V]] = + Schema.NonEmptyMap[K, V](keySchema, valueSchema, Chunk.empty) + implicit def map[K, V]( implicit keySchema: Schema[K], valueSchema: Schema[V] @@ -788,6 +819,62 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality { b.makeTraversal(self, keySchema <*> valueSchema) } + final case class NonEmptyMap[K, V]( + keySchema: Schema[K], + valueSchema: Schema[V], + override val annotations: Chunk[Any] = Chunk.empty + ) extends Collection[prelude.NonEmptyMap[K, V], (K, V)] { + self => + override type Accessors[Lens[_, _, _], Prism[_, _, _], Traversal[_, _]] = + Traversal[prelude.NonEmptyMap[K, V], (K, V)] + + override def annotate(annotation: Any): NonEmptyMap[K, V] = + copy(annotations = (annotations :+ annotation).distinct) + + override def defaultValue: scala.Either[String, prelude.NonEmptyMap[K, V]] = + keySchema.defaultValue.flatMap( + defaultKey => valueSchema.defaultValue.map(defaultValue => prelude.NonEmptyMap(defaultKey -> defaultValue)) + ) + + def fromChunk(chunk: Chunk[(K, V)]): prelude.NonEmptyMap[K, V] = + fromChunkOption(chunk).getOrElse(throw new IllegalArgumentException("NonEmptyMap cannot be empty")) + + def fromChunkOption(chunk: Chunk[(K, V)]): Option[prelude.NonEmptyMap[K, V]] = + NonEmptyChunk.fromChunk(chunk).map(prelude.NonEmptyMap.fromNonEmptyChunk) + + def toChunk(map: prelude.NonEmptyMap[K, V]): Chunk[(K, V)] = + Chunk.fromIterable(map.toList) + + override def makeAccessors(b: AccessorBuilder): b.Traversal[prelude.NonEmptyMap[K, V], (K, V)] = + b.makeTraversal(self, keySchema <*> valueSchema) + } + + final case class NonEmptySequence[Col, Elm, I]( + elementSchema: Schema[Elm], + fromChunkOption: Chunk[Elm] => Option[Col], + toChunk: Col => Chunk[Elm], + override val annotations: Chunk[Any] = Chunk.empty, + identity: I + ) extends Collection[Col, Elm] { + self => + override type Accessors[Lens[_, _, _], Prism[_, _, _], Traversal[_, _]] = Traversal[Col, Elm] + + def fromChunk(chunk: Chunk[Elm]): Col = + fromChunkOption(chunk).getOrElse( + throw new IllegalArgumentException(s"NonEmptySequence $identity cannot be empty") + ) + + override def annotate(annotation: Any): NonEmptySequence[Col, Elm, I] = + copy(annotations = (annotations :+ annotation).distinct) + + override def defaultValue: scala.util.Either[String, Col] = + elementSchema.defaultValue.map((fromChunk _).compose(Chunk(_))) + + override def makeAccessors(b: AccessorBuilder): b.Traversal[Col, Elm] = b.makeTraversal(self, elementSchema) + + override def toString: String = s"NonEmptySequence($elementSchema, $identity)" + } + final case class Set[A](elementSchema: Schema[A], override val annotations: Chunk[Any] = Chunk.empty) extends Collection[scala.collection.immutable.Set[A], A] { self =>