Skip to content

Commit

Permalink
Non empty collection schemas (#717)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Aug 7, 2024
1 parent 5a276f7 commit 94bf29d
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ object DynamicValueGen {
case Schema.Enum22(_, case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22, _) => anyDynamicValueOfEnum(Chunk(case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22))
case Schema.EnumN(_, cases, _) => anyDynamicValueOfEnum(Chunk.fromIterable(cases.toSeq))
case Schema.Sequence(schema, _, _, _, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_))
case Schema.NonEmptySequence(schema, _, _, _, _) => Gen.chunkOfBounded(1, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_))
case Schema.Map(ks, vs, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_))
case Schema.NonEmptyMap(ks, vs, _) => Gen.chunkOfBounded(1, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_))
case Schema.Set(schema, _) => Gen.setOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.SetValue(_))
case Schema.Optional(schema, _) => Gen.oneOf(anyDynamicSomeValueOfSchema(schema), Gen.const(DynamicValue.NoneValue))
case Schema.Tuple2(left, right, _) => anyDynamicTupleValue(left, right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,8 @@ import zio.json.JsonCodec._
import zio.json.JsonDecoder.{ JsonError, UnsafeJson }
import zio.json.ast.Json
import zio.json.internal.{ Lexer, RecordingReader, RetractReader, StringMatrix, WithRecordingReader, Write }
import zio.json.{
JsonCodec => ZJsonCodec,
JsonDecoder => ZJsonDecoder,
JsonEncoder => ZJsonEncoder,
JsonFieldDecoder,
JsonFieldEncoder
}
import zio.json.{JsonCodec => ZJsonCodec, JsonDecoder => ZJsonDecoder, JsonEncoder => ZJsonEncoder, JsonFieldDecoder, JsonFieldEncoder}
import zio.prelude.NonEmptyMap
import zio.schema._
import zio.schema.annotation._
import zio.schema.codec.DecodeError.ReadError
Expand Down Expand Up @@ -182,9 +177,11 @@ object JsonCodec {
//scalafmt: { maxColumn = 400, optIn.configStyleArguments = false }
private[codec] def schemaEncoder[A](schema: Schema[A], cfg: Config, discriminatorTuple: DiscriminatorTuple = Chunk.empty): ZJsonEncoder[A] =
schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder
case Schema.Sequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g)
case Schema.Map(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg)
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).encoder
case Schema.Sequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g)
case Schema.NonEmptySequence(schema, _, g, _, _) => ZJsonEncoder.chunk(schemaEncoder(schema, cfg, discriminatorTuple)).contramap(g)
case Schema.Map(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg)
case Schema.NonEmptyMap(ks, vs, _) => mapEncoder(ks, vs, discriminatorTuple, cfg).contramap[NonEmptyMap[Any, Any]](_.toMap.asInstanceOf[Map[Any, Any]]).asInstanceOf[ZJsonEncoder[A]]
case Schema.Set(s, _) =>
ZJsonEncoder.chunk(schemaEncoder(s, cfg, discriminatorTuple)).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, a, _) => transformEncoder(a.foldLeft(c)((s, a) => s.annotate(a)), g, cfg)
Expand Down Expand Up @@ -544,18 +541,20 @@ object JsonCodec {

//scalafmt: { maxColumn = 400, optIn.configStyleArguments = false }
private[codec] def schemaDecoder[A](schema: Schema[A], discriminator: Int = -1): ZJsonDecoder[A] = schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder
case Schema.Optional(codec, _) => option(schemaDecoder(codec, discriminator))
case Schema.Tuple2(left, right, _) => ZJsonDecoder.tuple2(schemaDecoder(left, -1), schemaDecoder(right, -1))
case Schema.Transform(c, f, _, a, _) => schemaDecoder(a.foldLeft(c)((s, a) => s.annotate(a)), discriminator).mapOrFail(f)
case Schema.Sequence(codec, f, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(f)
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case Schema.Set(s, _) => ZJsonDecoder.chunk(schemaDecoder(s, -1)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk, schema.annotations.contains(rejectExtraFields()))
case Schema.Either(left, right, _) => ZJsonDecoder.either(schemaDecoder(left, -1), schemaDecoder(right, -1))
case s @ Schema.Fallback(_, _, _, _) => fallbackDecoder(s)
case l @ Schema.Lazy(_) => schemaDecoder(l.schema, discriminator)
case Schema.Primitive(standardType, _) => primitiveCodec(standardType).decoder
case Schema.Optional(codec, _) => option(schemaDecoder(codec, discriminator))
case Schema.Tuple2(left, right, _) => ZJsonDecoder.tuple2(schemaDecoder(left, -1), schemaDecoder(right, -1))
case Schema.Transform(c, f, _, a, _) => schemaDecoder(a.foldLeft(c)((s, a) => s.annotate(a)), discriminator).mapOrFail(f)
case Schema.Sequence(codec, f, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(f)
case s @ Schema.NonEmptySequence(codec, _, _, _, _) => ZJsonDecoder.chunk(schemaDecoder(codec, -1)).map(s.fromChunk)
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case Schema.NonEmptyMap(ks, vs, _) => mapDecoder(ks, vs).mapOrFail(m => NonEmptyMap.fromMapOption(m).toRight("NonEmptyMap expected"))
case Schema.Set(s, _) => ZJsonDecoder.chunk(schemaDecoder(s, -1)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk, schema.annotations.contains(rejectExtraFields()))
case Schema.Either(left, right, _) => ZJsonDecoder.either(schemaDecoder(left, -1), schemaDecoder(right, -1))
case s @ Schema.Fallback(_, _, _, _) => fallbackDecoder(s)
case l @ Schema.Lazy(_) => schemaDecoder(l.schema, discriminator)
//case Schema.Meta(_, _) => astDecoder
case s @ Schema.CaseClass0(_, _, _) => caseClass0Decoder(discriminator, s)
case s @ Schema.CaseClass1(_, _, _, _) => caseClass1Decoder(discriminator, s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package zio.schema.optics
import scala.collection.immutable.ListMap

import zio.optics._
import zio.prelude.NonEmptyMap
import zio.schema._
import zio.{ Chunk, ChunkBuilder }

Expand Down Expand Up @@ -44,14 +45,24 @@ object ZioOpticsBuilder extends AccessorBuilder {
collection match {
case seq @ Schema.Sequence(_, _, _, _, _) =>
ZTraversal[S, S, A, A](
ZioOpticsBuilder.makeSeqTraversalGet(seq),
ZioOpticsBuilder.makeSeqTraversalGet(seq.toChunk),
ZioOpticsBuilder.makeSeqTraversalSet(seq)
)
case seq @ Schema.NonEmptySequence(_, _, _, _, _) =>
ZTraversal[S, S, A, A](
ZioOpticsBuilder.makeSeqTraversalGet(seq.toChunk),
ZioOpticsBuilder.makeNonEmptySeqTraversalSet(seq)
)
case Schema.Map(_: Schema[k], _: Schema[v], _) =>
ZTraversal(
ZioOpticsBuilder.makeMapTraversalGet[k, v],
ZioOpticsBuilder.makeMapTraversalSet[k, v]
)
case Schema.NonEmptyMap(_: Schema[k], _: Schema[v], _) =>
ZTraversal(
ZioOpticsBuilder.makeMapTraversalGet[k, v],
ZioOpticsBuilder.makeNonEmptyMapTraversalSet[k, v]
)
case Schema.Set(_, _) =>
ZTraversal(
ZioOpticsBuilder.makeSetTraversalGet[A],
Expand Down Expand Up @@ -103,9 +114,9 @@ object ZioOpticsBuilder extends AccessorBuilder {
}

private[optics] def makeSeqTraversalGet[S, A](
collection: Schema.Sequence[S, A, _]
toChunk: S => Chunk[A]
): S => Either[(OpticFailure, S), Chunk[A]] = { (whole: S) =>
Right(collection.toChunk(whole))
Right(toChunk(whole))
}

private[optics] def makeSeqTraversalSet[S, A](
Expand All @@ -123,15 +134,40 @@ object ZioOpticsBuilder extends AccessorBuilder {
}
Right(collection.fromChunk(builder.result()))
}
private[optics] def makeNonEmptySeqTraversalSet[S, A](
collection: Schema.NonEmptySequence[S, A, _]
): Chunk[A] => S => Either[(OpticFailure, S), S] = { (piece: Chunk[A]) => (whole: S) =>
val builder = ChunkBuilder.make[A]()
val leftIterator = collection.toChunk(whole).iterator
val rightIterator = piece.iterator
while (leftIterator.hasNext && rightIterator.hasNext) {
val _ = leftIterator.next()
builder += rightIterator.next()
}
while (leftIterator.hasNext) {
builder += leftIterator.next()
}
Right(collection.fromChunk(builder.result()))
}

private[optics] def makeMapTraversalGet[K, V](whole: Map[K, V]): Either[(OpticFailure, Map[K, V]), Chunk[(K, V)]] =
Right(Chunk.fromIterable(whole))

private[optics] def makeMapTraversalGet[K, V](
whole: NonEmptyMap[K, V]
): Either[(OpticFailure, NonEmptyMap[K, V]), Chunk[(K, V)]] =
Right(Chunk.fromIterable(whole.toMap))

private[optics] def makeMapTraversalSet[K, V]
: Chunk[(K, V)] => Map[K, V] => Either[(OpticFailure, Map[K, V]), Map[K, V]] = {
(piece: Chunk[(K, V)]) => (whole: Map[K, V]) =>
Right(whole ++ piece.toList)
}
private[optics] def makeNonEmptyMapTraversalSet[K, V]
: Chunk[(K, V)] => NonEmptyMap[K, V] => Either[(OpticFailure, NonEmptyMap[K, V]), NonEmptyMap[K, V]] = {
(piece: Chunk[(K, V)]) => (whole: NonEmptyMap[K, V]) =>
Right(whole ++ piece.toList)
}

private[optics] def makeSetTraversalGet[A](whole: Set[A]): Either[(OpticFailure, Set[A]), Chunk[A]] =
Right(Chunk.fromIterable(whole))
Expand Down
89 changes: 88 additions & 1 deletion zio-schema/shared/src/main/scala/zio/schema/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import java.time.temporal.ChronoUnit
import scala.annotation.tailrec
import scala.collection.immutable.ListMap

import zio.prelude.NonEmptySet
import zio.schema.annotation._
import zio.schema.internal.SourceLocation
import zio.schema.meta._
import zio.schema.validation._
import zio.{ Chunk, Unsafe }
import zio.{ Chunk, NonEmptyChunk, Unsafe, prelude }

/**
* A `Schema[A]` describes the structure of some data type `A`, in terms of case classes,
Expand Down Expand Up @@ -196,6 +197,8 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality {
schema match {
case Sequence(schema, _, toChunk, _, _) =>
toChunk(value).flatMap(value => loop(value, schema))
case nes @ NonEmptySequence(schema, _, _, _, _) =>
nes.toChunk(value).flatMap(value => loop(value, schema))
case Transform(schema, _, g, _, _) =>
g(value) match {
case Right(value) => loop(value, schema)
Expand All @@ -211,6 +214,10 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality {
loop(tuple.extract1(value), left) ++ loop(tuple.extract2(value), right)
case l @ Lazy(_) =>
loop(value, l.schema)
case Schema.NonEmptyMap(ks, vs, _) =>
Chunk.fromIterable(value.toMap.keys).flatMap(loop(_, ks)) ++ Chunk
.fromIterable(value.values)
.flatMap(loop(_, vs))
case Schema.Map(ks, vs, _) =>
Chunk.fromIterable(value.keys).flatMap(loop(_, ks)) ++ Chunk.fromIterable(value.values).flatMap(loop(_, vs))
case set @ Schema.Set(as, _) =>
Expand Down Expand Up @@ -299,6 +306,30 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality {
implicit def chunk[A](implicit schemaA: Schema[A]): Schema[Chunk[A]] =
Schema.Sequence[Chunk[A], A, String](schemaA, identity, identity, Chunk.empty, "Chunk")

implicit def nonEmptyChunk[A](implicit schemaA: Schema[A]): Schema[NonEmptyChunk[A]] =
Schema.NonEmptySequence[NonEmptyChunk[A], A, String](
schemaA,
NonEmptyChunk.fromChunk,
_.toChunk,
Chunk.empty,
"NonEmptyChunk"
)

implicit def nonEmptySet[A](implicit schemaA: Schema[A]): Schema[NonEmptySet[A]] =
Schema.NonEmptySequence[NonEmptySet[A], A, String](
schemaA,
chunk => NonEmptySet.fromSetOption(chunk.toSet),
_.toNonEmptyChunk.toChunk,
Chunk.empty,
"NonEmptySet"
)

implicit def nonEmptyMap[K, V](
implicit keySchema: Schema[K],
valueSchema: Schema[V]
): Schema[prelude.NonEmptyMap[K, V]] =
Schema.NonEmptyMap[K, V](keySchema, valueSchema, Chunk.empty)

implicit def map[K, V](
implicit keySchema: Schema[K],
valueSchema: Schema[V]
Expand Down Expand Up @@ -788,6 +819,62 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality {
b.makeTraversal(self, keySchema <*> valueSchema)
}

final case class NonEmptyMap[K, V](
keySchema: Schema[K],
valueSchema: Schema[V],
override val annotations: Chunk[Any] = Chunk.empty
) extends Collection[prelude.NonEmptyMap[K, V], (K, V)] {
self =>
override type Accessors[Lens[_, _, _], Prism[_, _, _], Traversal[_, _]] =
Traversal[prelude.NonEmptyMap[K, V], (K, V)]

override def annotate(annotation: Any): NonEmptyMap[K, V] =
copy(annotations = (annotations :+ annotation).distinct)

override def defaultValue: scala.Either[String, prelude.NonEmptyMap[K, V]] =
keySchema.defaultValue.flatMap(
defaultKey => valueSchema.defaultValue.map(defaultValue => prelude.NonEmptyMap(defaultKey -> defaultValue))
)

def fromChunk(chunk: Chunk[(K, V)]): prelude.NonEmptyMap[K, V] =
fromChunkOption(chunk).getOrElse(throw new IllegalArgumentException("NonEmptyMap cannot be empty"))

def fromChunkOption(chunk: Chunk[(K, V)]): Option[prelude.NonEmptyMap[K, V]] =
NonEmptyChunk.fromChunk(chunk).map(prelude.NonEmptyMap.fromNonEmptyChunk)

def toChunk(map: prelude.NonEmptyMap[K, V]): Chunk[(K, V)] =
Chunk.fromIterable(map.toList)

override def makeAccessors(b: AccessorBuilder): b.Traversal[prelude.NonEmptyMap[K, V], (K, V)] =
b.makeTraversal(self, keySchema <*> valueSchema)
}

final case class NonEmptySequence[Col, Elm, I](
elementSchema: Schema[Elm],
fromChunkOption: Chunk[Elm] => Option[Col],
toChunk: Col => Chunk[Elm],
override val annotations: Chunk[Any] = Chunk.empty,
identity: I
) extends Collection[Col, Elm] {
self =>
override type Accessors[Lens[_, _, _], Prism[_, _, _], Traversal[_, _]] = Traversal[Col, Elm]

def fromChunk(chunk: Chunk[Elm]): Col =
fromChunkOption(chunk).getOrElse(
throw new IllegalArgumentException(s"NonEmptySequence $identity cannot be empty")
)

override def annotate(annotation: Any): NonEmptySequence[Col, Elm, I] =
copy(annotations = (annotations :+ annotation).distinct)

override def defaultValue: scala.util.Either[String, Col] =
elementSchema.defaultValue.map((fromChunk _).compose(Chunk(_)))

override def makeAccessors(b: AccessorBuilder): b.Traversal[Col, Elm] = b.makeTraversal(self, elementSchema)

override def toString: String = s"NonEmptySequence($elementSchema, $identity)"
}

final case class Set[A](elementSchema: Schema[A], override val annotations: Chunk[Any] = Chunk.empty)
extends Collection[scala.collection.immutable.Set[A], A] {
self =>
Expand Down

0 comments on commit 94bf29d

Please sign in to comment.